Biostat 203B Homework 5

Due Mar 22 @ 11:59PM

Author

Yingxin Zhang, UID: 006140202

Display machine information:

sessionInfo()
R version 4.3.2 (2023-10-31)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.7.3

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] htmlwidgets_1.6.4 compiler_4.3.2    fastmap_1.1.1     cli_3.6.1        
 [5] tools_4.3.2       htmltools_0.5.7   rstudioapi_0.15.0 yaml_2.3.7       
 [9] rmarkdown_2.25    knitr_1.45        jsonlite_1.8.7    xfun_0.41        
[13] digest_0.6.33     rlang_1.1.1       evaluate_0.23    

Display my machine memory.

memuse::Sys.meminfo()
Totalram:  16.000 GiB 
Freeram:   77.312 MiB 

Load database libraries and the tidyverse frontend:

library(dbplyr)
library(DBI)
library(gt)
library(gtsummary)
library(tidyverse)
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.4
✔ forcats   1.0.0     ✔ stringr   1.5.0
✔ ggplot2   3.4.4     ✔ tibble    3.2.1
✔ lubridate 1.9.3     ✔ tidyr     1.3.0
✔ purrr     1.0.2     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::ident()  masks dbplyr::ident()
✖ dplyr::lag()    masks stats::lag()
✖ dplyr::sql()    masks dbplyr::sql()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(lubridate)
library(miceRanger)
library(GGally)
Registered S3 method overwritten by 'GGally':
  method from   
  +.gg   ggplot2
library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
✔ broom        1.0.5     ✔ rsample      1.2.0
✔ dials        1.2.0     ✔ tune         1.1.2
✔ infer        1.0.5     ✔ workflows    1.1.3
✔ modeldata    1.2.0     ✔ workflowsets 1.0.1
✔ parsnip      1.1.1     ✔ yardstick    1.2.0
✔ recipes      1.0.8     
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ recipes::all_double()  masks gtsummary::all_double()
✖ recipes::all_factor()  masks gtsummary::all_factor()
✖ recipes::all_integer() masks gtsummary::all_integer()
✖ recipes::all_logical() masks gtsummary::all_logical()
✖ recipes::all_numeric() masks gtsummary::all_numeric()
✖ scales::discard()      masks purrr::discard()
✖ dplyr::filter()        masks stats::filter()
✖ recipes::fixed()       masks stringr::fixed()
✖ dplyr::ident()         masks dbplyr::ident()
✖ dplyr::lag()           masks stats::lag()
✖ yardstick::spec()      masks readr::spec()
✖ dplyr::sql()           masks dbplyr::sql()
✖ recipes::step()        masks stats::step()
• Dig deeper into tidy modeling with R at https://www.tmwr.org
library(keras)

Attaching package: 'keras'

The following object is masked from 'package:yardstick':

    get_weights
library(ranger)
library(stacks)
library(xgboost)

Attaching package: 'xgboost'

The following object is masked from 'package:dplyr':

    slice
library(ggplot2)

Predicting ICU duration

Using the ICU cohort mimiciv_icu_cohort.rds you built in Homework 4, develop at least three machine learning approaches (logistic regression with enet regularization, random forest, boosting, SVM, MLP, etc.) plus a model stacking approach for predicting whether a patient’s ICU stay will be longer than 2 days. You should use the los_long variable as the outcome. Your algorithms can use patient demographic information (gender, age at ICU intime, marital status, race), ICU admission information (first care unit), the last lab measurements before the ICU stay, and the first vital measurements during the ICU stay as features. You are welcome to use any feature engineering techniques you think are appropriate, but make sure not to use features that are not available at an ICU stay’s intime. For instance, last_careunit cannot be used in your algorithms.

Answer:

  • Note: I also include saved RDS files (fitted ML models) to reduce render time.

1. Data preprocessing and feature engineering

First, I load the ICU cohort mimic_icu_cohort.rds and select the features I need for the machine learning algorithms.

# define the variables I need
ctg_vars <- c("race", "insurance", "marital_status", "gender", 
              "first_careunit")

lab_vars <- c("Creatinine", "Potassium", "Sodium", "Chloride", "Bicarbonate", 
              "Hematocrit", "Glucose", "White_Blood_Cells")

vital_vars <- c("Heart_Rate", "Non_Invasive_Blood_Pressure_systolic", 
                "Non_Invasive_Blood_Pressure_diastolic", 
                "Temperature_Fahrenheit", "Respiratory_Rate")

all_vars <- c(ctg_vars, lab_vars, vital_vars, "los_long", "age_intime", 
              "subject_id", "hadm_id", "stay_id")

# import data from rds file
mimic_icu_cohort <- read_rds("../hw4/mimiciv_shiny/mimic_icu_cohort.rds") |>
  rename_with(~ gsub(" ", "_", .x)) |>
  select(all_of(all_vars)) |>
  as_tibble()

# convert the categorical variables to factor
mimic_icu_cohort <- mimic_icu_cohort |>
  mutate(across(all_of(ctg_vars), factor))

# summary of the dataset (table)
mimic_icu_cohort |>
  select(-subject_id, -hadm_id, -stay_id) |>
  tbl_summary(by = los_long)
Characteristic  FALSE, N = 38,050¹  TRUE, N = 35,131¹
race

    ASIAN 1,148 (3.0%) 1,007 (2.9%)
    BLACK 4,311 (11%) 3,649 (10%)
    HISPANIC 1,492 (3.9%) 1,249 (3.6%)
    Other 5,160 (14%) 5,596 (16%)
    WHITE 25,939 (68%) 23,630 (67%)
insurance

    Medicaid 3,060 (8.0%) 2,468 (7.0%)
    Medicare 16,489 (43%) 16,602 (47%)
    Other 18,501 (49%) 16,061 (46%)
marital_status

    DIVORCED 2,843 (8.0%) 2,561 (7.9%)
    MARRIED 16,984 (48%) 15,784 (49%)
    SINGLE 11,175 (31%) 9,683 (30%)
    WIDOWED 4,699 (13%) 4,339 (13%)
    Unknown 2,349 2,764
gender

    F 17,014 (45%) 15,349 (44%)
    M 21,036 (55%) 19,782 (56%)
first_careunit

    Cardiac Vascular Intensive Care Unit (CVICU) 5,827 (15%) 5,755 (16%)
    Medical Intensive Care Unit (MICU) 8,782 (23%) 7,116 (20%)
    Medical/Surgical Intensive Care Unit (MICU/SICU) 7,147 (19%) 5,586 (16%)
    Surgical Intensive Care Unit (SICU) 5,654 (15%) 5,507 (16%)
    Other 10,640 (28%) 11,167 (32%)
Creatinine 1.00 (0.80, 1.40) 1.00 (0.80, 1.60)
    Unknown 2,599 3,171
Potassium 4.20 (3.80, 4.60) 4.20 (3.80, 4.70)
    Unknown 4,189 4,712
Sodium 139.0 (136.0, 141.0) 138.0 (135.0, 141.0)
    Unknown 4,174 4,698
Chloride 102 (98, 105) 102 (98, 105)
    Unknown 4,175 4,708
Bicarbonate 25.0 (22.0, 27.0) 24.0 (21.0, 27.0)
    Unknown 4,270 4,780
Hematocrit 36 (30, 40) 35 (29, 40)
    Unknown 2,236 2,781
Glucose 118 (98, 153) 122 (100, 159)
    Unknown 4,282 4,817
White_Blood_Cells 9.0 (6.6, 12.6) 9.7 (7.0, 13.8)
    Unknown 2,272 2,822
Heart_Rate 85 (74, 99) 88 (76, 103)
    Unknown 17 1
Non_Invasive_Blood_Pressure_systolic 122 (107, 139) 120 (104, 138)
    Unknown 732 247
Non_Invasive_Blood_Pressure_diastolic 68 (58, 80) 66 (55, 79)
    Unknown 733 250
Temperature_Fahrenheit 98.10 (97.60, 98.60) 98.20 (97.60, 98.80)
    Unknown 1,189 173
Respiratory_Rate 18.0 (15.0, 22.0) 19.0 (15.0, 23.0)
    Unknown 87 11
age_intime 65 (53, 77) 67 (56, 78)
¹ n (%); Median (IQR)

Then, I check the missingness of the variables and visualize it as follows. The plot shows that missingness is not severe: every variable is below 15% missing. Therefore, I impute the missing values rather than remove the variables.

# display the missing distribution of the variables and visualize it
mimic_icu_cohort |>
  map_df(~sum(is.na(.x))/nrow(mimic_icu_cohort)) |>
  gather(variable, missing) |>
  filter(missing > 0) |>
  arrange(desc(missing)) |>
  ggplot(aes(x = reorder(variable, missing), y = missing)) +
  geom_col() +
  labs(title = "Missingness of Variables",
       x = "Variables",
       y = "Proportion of Missingness") +
  # label each bar with its percentage, rounded to 0.01%
  geom_text(aes(label = scales::percent(missing, accuracy = 0.01)), 
            vjust = -1, size = 3) +
  coord_cartesian(ylim = c(0, 0.15)) +
  theme(axis.text.x = element_text(angle = 45, hjust=1)) +
  theme(panel.background = element_rect(fill = "white"),
        panel.grid = element_line(color = "gray", linewidth = 0.2),
        panel.border = element_rect(fill = NA, linewidth = 0.5))

Before imputing the missing values, I convert outliers of the continuous variables to missing values using the IQR rule (values beyond 1.5 × IQR from the quartiles).

# write a function to convert the outliers to missing values
outlier_to_na <- function(x) {
  q1 <- quantile(x, 0.25, na.rm = TRUE)
  q3 <- quantile(x, 0.75, na.rm = TRUE)
  iqr <- q3 - q1
  x[x < (q1 - 1.5 * iqr) | x > (q3 + 1.5 * iqr)] <- NA
  x
}
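As a quick sanity check of the IQR rule (a toy vector, not part of the cohort pipeline), only the extreme value is converted to NA:

```r
# Toy check: q1 = 2.25, q3 = 4.75, so values outside [-1.5, 8.5] become NA
outlier_to_na(c(1, 2, 3, 4, 5, 100))
#> [1]  1  2  3  4  5 NA
```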

# convert the outliers (IQR method) to missing values for lab and vital events
# (wrap the external vectors in all_of() to avoid the tidyselect deprecation
# warning about bare external vectors in selections)
mimic_icu_replace <- mimic_icu_cohort |>
  mutate(across(all_of(c(lab_vars, vital_vars)), outlier_to_na))

Here, I use miceRanger to impute the missing values. To save render time, I check if the imputed dataset is already saved. If not, I impute the missing values using miceRanger and save the imputed dataset as mimic_icu_mice.rds.

if (file.exists("mimic_icu_mice.rds")) {
  mimic_icu_mice <- read_rds("mimic_icu_mice.rds")
} else {
  # impute the missing values using miceRanger
  seqTime <- system.time(
    mimic_icu_mice <- miceRanger(
      mimic_icu_replace, 
      m = 3, 
      maxit = 10,
      returnModels = FALSE, 
      verbose = TRUE
    )
  )
  )
  mimic_icu_mice |>
    write_rds("mimic_icu_mice.rds")
}

To check the imputed values, I plot the distributions of the imputed variables (black) against the original ones (red). The imputed distributions closely track the originals, suggesting the imputation is reasonable.

# plot distributions of the imputed variables and compare with the original ones
plotDistributions(mimic_icu_mice, vars = 'allNumeric')

After imputing the missing values, I select the first imputed dataset, convert non-numeric variables to factors, and scale the numeric variables.

# choose the first imputed dataset
mimic_icu_imputed <- completeData(mimic_icu_mice)[[1]]

# convert non-numeric variables to factors and scale numeric variables
mimic_icu_final <- mimic_icu_imputed |>
  mutate_if(is.character, as.factor) |>
  mutate_if(is.numeric, scale) |>
  mutate_if(is.logical, as.factor)

2. Partition data into 50% training set and 50% test set

Stratify the partitioning according to los_long. For grading purposes, sort the data by subject_id, hadm_id, and stay_id, and use the seed 203 for the initial data split.

library(rsample)
set.seed(203)

# sort
mimic_icu_final <- mimic_icu_final |>
  arrange(subject_id, hadm_id, stay_id)

data_split <- initial_split(
  mimic_icu_final, 
  # stratify by los_long
  strata = "los_long", 
  prop = 0.5
  )

data_split
<Training/Testing/Total>
<36590/36591/73181>
# training set
train_set <- training(data_split) |>
  select(-subject_id, -hadm_id, -stay_id)
dim(train_set)
[1] 36590    20
# testing set
test_set <- testing(data_split) |>
  select(-subject_id, -hadm_id, -stay_id)
dim(test_set)
[1] 36591    20
# recipe
icu_recipe <- recipe(los_long ~ ., data = train_set) |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_numeric_predictors()) |> 
  print()
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs 
Number of variables by role
outcome:    1
predictor: 19
── Operations 
• Dummy variables from: all_nominal_predictors()
• Zero variance filter on: all_numeric_predictors()

3. Train and tune the models using the training set

Here I use logistic regression with enet regularization, random forest, boosting, and model stacking to predict whether a patient’s ICU stay will be longer than 2 days.

First, I set up the cross-validation folds shared by all models.

# set cross-validation partitions
set.seed(203)

folds <- vfold_cv(train_set, v = 5)
folds
#  5-fold cross-validation 
# A tibble: 5 × 2
  splits               id   
  <list>               <chr>
1 <split [29272/7318]> Fold1
2 <split [29272/7318]> Fold2
3 <split [29272/7318]> Fold3
4 <split [29272/7318]> Fold4
5 <split [29272/7318]> Fold5

Then I set up the enet logistic regression, random forest, and boosting models.

3.1 Logistic regression with enet regularization

# set up the logistic regression model
logit_model <- logistic_reg(penalty = tune(), mixture = tune()) |>
  set_engine("glmnet", standardize = TRUE) |>
  print()
Logistic Regression Model Specification (classification)

Main Arguments:
  penalty = tune()
  mixture = tune()

Engine-Specific Arguments:
  standardize = TRUE

Computational engine: glmnet 
# bundle the recipe (R) and model into workflow.
logit_wf <- workflow() |>
  add_recipe(icu_recipe) |>
  add_model(logit_model) |>
  print()
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)

Main Arguments:
  penalty = tune()
  mixture = tune()

Engine-Specific Arguments:
  standardize = TRUE

Computational engine: glmnet 
# tune the penalty and mixture hyperparameters
logit_grid <- grid_regular(
  penalty(range = c(-6, 3)), 
  mixture(), 
  levels = c(100, 5))

# fit CV
if (file.exists("logit_res.rds")) {
  logit_res <- read_rds("logit_res.rds")
} else {
  logit_res <- 
    tune_grid(
      object = logit_wf, 
      resamples = folds, 
      grid = logit_grid,
      control = control_stack_grid()
    )
  write_rds(logit_res, "logit_res.rds")
}
logit_res
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 5
  splits               id    .metrics             .notes           .predictions
  <list>               <chr> <list>               <list>           <list>      
1 <split [29272/7318]> Fold1 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
2 <split [29272/7318]> Fold2 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
3 <split [29272/7318]> Fold3 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
4 <split [29272/7318]> Fold4 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
5 <split [29272/7318]> Fold5 <tibble [1,000 × 6]> <tibble [0 × 3]> <tibble>    
# visualize the CV results
logit_res |>
  # aggregate metrics from K folds
  collect_metrics() |>
  print(width = Inf) |>
  filter(.metric == "roc_auc") |>
  ggplot(mapping = aes(x = penalty, y = mean, color = factor(mixture))) +
  geom_point() +
  labs(x = "Penalty", y = "CV AUC") +
  scale_x_log10()
# A tibble: 1,000 × 8
      penalty mixture .metric  .estimator  mean     n std_err
        <dbl>   <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
 1 0.000001         0 accuracy binary     0.575     5 0.00351
 2 0.000001         0 roc_auc  binary     0.606     5 0.00349
 3 0.00000123       0 accuracy binary     0.575     5 0.00351
 4 0.00000123       0 roc_auc  binary     0.606     5 0.00349
 5 0.00000152       0 accuracy binary     0.575     5 0.00351
 6 0.00000152       0 roc_auc  binary     0.606     5 0.00349
 7 0.00000187       0 accuracy binary     0.575     5 0.00351
 8 0.00000187       0 roc_auc  binary     0.606     5 0.00349
 9 0.00000231       0 accuracy binary     0.575     5 0.00351
10 0.00000231       0 roc_auc  binary     0.606     5 0.00349
   .config               
   <chr>                 
 1 Preprocessor1_Model001
 2 Preprocessor1_Model001
 3 Preprocessor1_Model002
 4 Preprocessor1_Model002
 5 Preprocessor1_Model003
 6 Preprocessor1_Model003
 7 Preprocessor1_Model004
 8 Preprocessor1_Model004
 9 Preprocessor1_Model005
10 Preprocessor1_Model005
# ℹ 990 more rows

# the top 5 models
logit_res |>
  show_best("roc_auc")
# A tibble: 5 × 8
  penalty mixture .metric .estimator  mean     n std_err .config               
    <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                 
1 0.00152    0.75 roc_auc binary     0.606     5 0.00352 Preprocessor1_Model336
2 0.00123    1    roc_auc binary     0.606     5 0.00351 Preprocessor1_Model435
3 0.00231    0.5  roc_auc binary     0.606     5 0.00352 Preprocessor1_Model238
4 0.001      1    roc_auc binary     0.606     5 0.00351 Preprocessor1_Model434
5 0.00187    0.75 roc_auc binary     0.606     5 0.00353 Preprocessor1_Model337
# select the best model
best_logit <- logit_res |>
  select_best("roc_auc")
best_logit
# A tibble: 1 × 3
  penalty mixture .config               
    <dbl>   <dbl> <chr>                 
1 0.00152    0.75 Preprocessor1_Model336
# final workflow
final_logit <- logit_wf |>
  finalize_workflow(best_logit)
final_logit
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)

Main Arguments:
  penalty = 0.00151991108295293
  mixture = 0.75

Engine-Specific Arguments:
  standardize = TRUE

Computational engine: glmnet 

3.2 Random forest

# set up the random forest model
rf_mod <- 
  rand_forest(
    mode = "classification",
    # Number of predictors randomly sampled in each split
    mtry = tune(),
    # Number of trees in ensemble
    trees = tune()
  ) |> 
  set_engine("ranger", importance = "impurity")
rf_mod
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()

Engine-Specific Arguments:
  importance = impurity

Computational engine: ranger 
# bundle the recipe (R) and model into workflow.
rf_wf <- workflow() |>
  add_recipe(icu_recipe) |>
  add_model(rf_mod)
rf_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()

Engine-Specific Arguments:
  importance = impurity

Computational engine: ranger 
# tune the number of trees and the number of features to use in each split
rf_grid <- grid_regular(
  trees(range = c(300L, 800L)), 
  mtry(range = c(1L, 5L)),
  levels = c(10, 10)
  )

# fit CV
if (file.exists("rf_res.rds")) {
  rf_res <- read_rds("rf_res.rds")
} else {
  rf_res <- 
  tune_grid(
    object = rf_wf, 
    resamples = folds, 
    grid = rf_grid,
    control = control_stack_grid()
  )
  write_rds(rf_res, "rf_res.rds")
}
rf_res
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 5
  splits               id    .metrics           .notes           .predictions
  <list>               <chr> <list>             <list>           <list>      
1 <split [29272/7318]> Fold1 <tibble [100 × 6]> <tibble [0 × 3]> <tibble>    
2 <split [29272/7318]> Fold2 <tibble [100 × 6]> <tibble [0 × 3]> <tibble>    
3 <split [29272/7318]> Fold3 <tibble [100 × 6]> <tibble [0 × 3]> <tibble>    
4 <split [29272/7318]> Fold4 <tibble [100 × 6]> <tibble [0 × 3]> <tibble>    
5 <split [29272/7318]> Fold5 <tibble [100 × 6]> <tibble [0 × 3]> <tibble>    
# visualize CV results
rf_res |>
  collect_metrics() |>
  print(width = Inf) |>
  filter(.metric == "roc_auc") |>
  ggplot(mapping = aes(x = trees, y = mean, color = factor(mtry))) +
  geom_point() + 
  # geom_line() + 
  labs(x = "Num. of Trees", y = "CV AUC")
# A tibble: 100 × 8
    mtry trees .metric  .estimator  mean     n std_err .config              
   <int> <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
 1     1   300 accuracy binary     0.574     5 0.00554 Preprocessor1_Model01
 2     1   300 roc_auc  binary     0.615     5 0.00401 Preprocessor1_Model01
 3     1   355 accuracy binary     0.572     5 0.00477 Preprocessor1_Model02
 4     1   355 roc_auc  binary     0.614     5 0.00377 Preprocessor1_Model02
 5     1   411 accuracy binary     0.572     5 0.00462 Preprocessor1_Model03
 6     1   411 roc_auc  binary     0.615     5 0.00335 Preprocessor1_Model03
 7     1   466 accuracy binary     0.572     5 0.00540 Preprocessor1_Model04
 8     1   466 roc_auc  binary     0.615     5 0.00366 Preprocessor1_Model04
 9     1   522 accuracy binary     0.572     5 0.00509 Preprocessor1_Model05
10     1   522 roc_auc  binary     0.615     5 0.00401 Preprocessor1_Model05
# ℹ 90 more rows

# the top 5 models
rf_res |>
  show_best("roc_auc")
# A tibble: 5 × 8
   mtry trees .metric .estimator  mean     n std_err .config              
  <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
1     2   688 roc_auc binary     0.622     5 0.00357 Preprocessor1_Model18
2     3   800 roc_auc binary     0.621     5 0.00359 Preprocessor1_Model30
3     4   744 roc_auc binary     0.621     5 0.00360 Preprocessor1_Model39
4     3   744 roc_auc binary     0.621     5 0.00396 Preprocessor1_Model29
5     4   633 roc_auc binary     0.621     5 0.00387 Preprocessor1_Model37
# the best model
best_rf <- rf_res |>
  select_best("roc_auc")
best_rf
# A tibble: 1 × 3
   mtry trees .config              
  <int> <int> <chr>                
1     2   688 Preprocessor1_Model18
# final workflow
final_rf <- rf_wf |>
  finalize_workflow(best_rf)
final_rf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  mtry = 2
  trees = 688

Engine-Specific Arguments:
  importance = impurity

Computational engine: ranger 

3.3 Boosting

# set up the boosting model
gb_mod <- 
  boost_tree(
    mode = "classification",
    trees = 1000, 
    tree_depth = tune(),
    learn_rate = tune()
  ) |> 
  set_engine("xgboost")
gb_mod
Boosted Tree Model Specification (classification)

Main Arguments:
  trees = 1000
  tree_depth = tune()
  learn_rate = tune()

Computational engine: xgboost 
# bundle the recipe (R) and model into workflow.
gb_wf <- workflow() |>
  add_recipe(icu_recipe) |>
  add_model(gb_mod)
gb_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Boosted Tree Model Specification (classification)

Main Arguments:
  trees = 1000
  tree_depth = tune()
  learn_rate = tune()

Computational engine: xgboost 
# tune the tree depth and learning rate
gb_grid <- grid_regular(
  tree_depth(range = c(1L, 5L)),
  learn_rate(range = c(-4, 1), trans = log10_trans()),
  levels = c(3, 10)
  )

# fit CV
if (file.exists("gb_res.rds")) {
  gb_res <- read_rds("gb_res.rds")
} else {
  gb_res <- tune_grid(
    object = gb_wf,
    resamples = folds,
    grid = gb_grid,
    control = control_stack_grid()
    )
  write_rds(gb_res, "gb_res.rds")
}
gb_res
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 5
  splits               id    .metrics          .notes           .predictions
  <list>               <chr> <list>            <list>           <list>      
1 <split [29272/7318]> Fold1 <tibble [60 × 6]> <tibble [0 × 3]> <tibble>    
2 <split [29272/7318]> Fold2 <tibble [60 × 6]> <tibble [0 × 3]> <tibble>    
3 <split [29272/7318]> Fold3 <tibble [60 × 6]> <tibble [0 × 3]> <tibble>    
4 <split [29272/7318]> Fold4 <tibble [60 × 6]> <tibble [0 × 3]> <tibble>    
5 <split [29272/7318]> Fold5 <tibble [60 × 6]> <tibble [0 × 3]> <tibble>    
# visualize the CV results
gb_res |>
  collect_metrics() |>
  print(width = Inf) |>
  filter(.metric == "roc_auc") |>
  ggplot(mapping = aes(x = learn_rate, y = mean, color = factor(tree_depth))) +
  geom_point() +
  labs(x = "Learning Rate", y = "CV AUC") +
  scale_x_log10()
# A tibble: 60 × 8
   tree_depth learn_rate .metric  .estimator  mean     n std_err
        <int>      <dbl> <chr>    <chr>      <dbl> <int>   <dbl>
 1          1   0.0001   accuracy binary     0.539     5 0.00356
 2          1   0.0001   roc_auc  binary     0.547     5 0.00679
 3          3   0.0001   accuracy binary     0.548     5 0.00378
 4          3   0.0001   roc_auc  binary     0.572     5 0.00292
 5          5   0.0001   accuracy binary     0.561     5 0.00398
 6          5   0.0001   roc_auc  binary     0.585     5 0.00283
 7          1   0.000359 accuracy binary     0.542     5 0.00445
 8          1   0.000359 roc_auc  binary     0.561     5 0.00277
 9          3   0.000359 accuracy binary     0.558     5 0.00403
10          3   0.000359 roc_auc  binary     0.585     5 0.00423
   .config              
   <chr>                
 1 Preprocessor1_Model01
 2 Preprocessor1_Model01
 3 Preprocessor1_Model02
 4 Preprocessor1_Model02
 5 Preprocessor1_Model03
 6 Preprocessor1_Model03
 7 Preprocessor1_Model04
 8 Preprocessor1_Model04
 9 Preprocessor1_Model05
10 Preprocessor1_Model05
# ℹ 50 more rows

# the top 5 models
gb_res |>
  show_best("roc_auc")
# A tibble: 5 × 8
  tree_depth learn_rate .metric .estimator  mean     n std_err .config          
       <int>      <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>            
1          5    0.00464 roc_auc binary     0.620     5 0.00254 Preprocessor1_Mo…
2          3    0.0167  roc_auc binary     0.620     5 0.00260 Preprocessor1_Mo…
3          5    0.0167  roc_auc binary     0.618     5 0.00244 Preprocessor1_Mo…
4          3    0.00464 roc_auc binary     0.617     5 0.00266 Preprocessor1_Mo…
5          3    0.0599  roc_auc binary     0.616     5 0.00283 Preprocessor1_Mo…
# the best model
best_gb <- gb_res |>
  select_best("roc_auc")
best_gb
# A tibble: 1 × 3
  tree_depth learn_rate .config              
       <int>      <dbl> <chr>                
1          5    0.00464 Preprocessor1_Model12
# final workflow
final_gb <- gb_wf |>
  finalize_workflow(best_gb)
final_gb
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_dummy()
• step_zv()

── Model ───────────────────────────────────────────────────────────────────────
Boosted Tree Model Specification (classification)

Main Arguments:
  trees = 1000
  tree_depth = 5
  learn_rate = 0.00464158883361278

Computational engine: xgboost 

3.4 Model stacking

# build the stacked ensemble
if (file.exists("icu_model_st.rds")) {
  icu_model_st <- read_rds("icu_model_st.rds")
} else {
  icu_model_st <- 
    # initialize the stack
    stacks() |>
    # add candidate members
    add_candidates(logit_res) |>
    add_candidates(rf_res) |>
    add_candidates(gb_res) |>
    # determine how to combine their predictions
    blend_predictions(
      penalty = 10^(-6:2),
      metrics = c("roc_auc")
      ) |>
    # fit the candidates with nonzero stacking coefficients
    fit_members()
  write_rds(icu_model_st, "icu_model_st.rds")
}
icu_model_st
── A stacked ensemble model ─────────────────────────────────────


Out of 217 possible candidate members, the ensemble retained 22.

Penalty: 0.001.

Mixture: 1.


The 10 highest weighted member classes are:
# A tibble: 10 × 3
   member                     type         weight
   <chr>                      <chr>         <dbl>
 1 .pred_TRUE_gb_res_1_12     boost_tree    0.873
 2 .pred_TRUE_rf_res_1_37     rand_forest   0.717
 3 .pred_TRUE_gb_res_1_17     boost_tree    0.504
 4 .pred_TRUE_rf_res_1_15     rand_forest   0.502
 5 .pred_TRUE_rf_res_1_39     rand_forest   0.460
 6 .pred_TRUE_rf_res_1_18     rand_forest   0.411
 7 .pred_TRUE_rf_res_1_48     rand_forest   0.409
 8 .pred_TRUE_logit_res_1_101 logistic_reg  0.337
 9 .pred_TRUE_rf_res_1_33     rand_forest   0.215
10 .pred_TRUE_logit_res_1_301 logistic_reg  0.173
# plot the results
autoplot(icu_model_st)

# show the relationship more directly
autoplot(icu_model_st, type = "members")

# see the top results
autoplot(icu_model_st, type = "weights")

4. Compare model classification performance on the test set

Report both the area under ROC curve and accuracy for each machine learning algorithm and the model stacking. Interpret the results. What are the most important features in predicting long ICU stays? How do the models compare in terms of performance and interpretability?

4.1 AUC and accuracy of each model

4.1.1 Logistic regression with enet regularization

I fit the logistic regression model on the whole training set and predict the test cases. The test metrics are as follows.

  • The AUC of the logistic regression model is 0.600. This means that there is a 60% chance that the model will be able to distinguish between a randomly chosen positive instance and a randomly chosen negative instance.

  • The accuracy is 0.573, which means that 57.3% of the logistic regression model’s predictions are correct.

# Fit the whole training set, then predict the test cases based on logistic regression
logit_fit <- 
  final_logit |>
  last_fit(data_split)

# Test metrics
logit_fit |> 
  collect_metrics()
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.573 Preprocessor1_Model1
2 roc_auc  binary         0.600 Preprocessor1_Model1
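The AUC above summarizes the whole ROC curve; the curve itself can be drawn from the held-out predictions (a sketch assuming los_long has factor levels FALSE/TRUE, so the TRUE probability is the second event level):

```r
# Plot the test-set ROC curve underlying the reported AUC
logit_fit |>
  collect_predictions() |>
  roc_curve(truth = los_long, .pred_TRUE, event_level = "second") |>
  autoplot()
```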

4.1.2 Random forest

I fit the random forest model on the whole training set and predict the test cases. The test metrics are as follows.

  • The AUC of the random forest model is 0.620. This means that there is a 62% chance that the model will be able to distinguish between a randomly chosen positive instance and a randomly chosen negative instance.

  • The accuracy is 0.586, which means that 58.6% of the random forest model’s predictions are correct.

# Fit the whole training set, then predict the test cases based on RF
rf_fit <- 
  final_rf |>
  last_fit(data_split)

# Test metrics
rf_fit |> 
  collect_metrics()
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.586 Preprocessor1_Model1
2 roc_auc  binary         0.620 Preprocessor1_Model1
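Since the ranger engine was fit with importance = "impurity", the question of which features matter most can be explored directly from the fitted workflow (a sketch; vip is an extra package not loaded above):

```r
library(vip)

# Extract the fitted ranger model from the last_fit() object and
# plot the top features by impurity importance
rf_fit |>
  extract_fit_parsnip() |>
  vip(num_features = 10)
```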

4.1.3 Boosting

I fit the boosting model on the whole training set and predict the test cases. The test metrics are as follows.

  • The AUC of the boosting model is 0.617. This means that there is a 61.7% chance that the model will be able to distinguish between a randomly chosen positive instance and a randomly chosen negative instance.

  • The accuracy is 0.583, which means that 58.3% of the boosting model’s predictions are correct.

# Fit the whole training set, then predict the test cases based on GB
gb_fit <- 
  final_gb |>
  last_fit(data_split)

# Test metrics
gb_fit |> 
  collect_metrics()
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.583 Preprocessor1_Model1
2 roc_auc  binary         0.617 Preprocessor1_Model1

4.1.4 Model stacking

I predict the test cases with the stacking model, first by probability and then by class. Then I compute the AUC and accuracy as follows.

  • The AUC of the stacking model is 0.622: there is a 62.2% chance that the model ranks a randomly chosen positive instance above a randomly chosen negative instance.

  • The accuracy is 0.586, which means that 58.6% of the stacking model’s predictions are correct.

# predict the test set by probability
icu_test_pred_st <- 
  icu_model_st |>
  predict(test_set, type = "prob") |>
  bind_cols(test_set) |>
  print()
# A tibble: 36,591 × 22
   .pred_FALSE .pred_TRUE race  insurance marital_status gender first_careunit  
         <dbl>      <dbl> <fct> <fct>     <fct>          <fct>  <fct>           
 1       0.540      0.460 WHITE Medicaid  WIDOWED        F      Medical Intensi…
 2       0.521      0.479 WHITE Other     MARRIED        F      Surgical Intens…
 3       0.648      0.352 WHITE Other     MARRIED        F      Surgical Intens…
 4       0.552      0.448 WHITE Other     MARRIED        F      Medical/Surgica…
 5       0.608      0.392 Other Medicare  SINGLE         F      Cardiac Vascula…
 6       0.531      0.469 WHITE Other     MARRIED        F      Other           
 7       0.341      0.659 WHITE Medicare  WIDOWED        F      Medical Intensi…
 8       0.445      0.555 WHITE Medicare  WIDOWED        F      Medical Intensi…
 9       0.504      0.496 WHITE Medicare  WIDOWED        M      Other           
10       0.366      0.634 Other Medicare  MARRIED        M      Other           
# ℹ 36,581 more rows
# ℹ 15 more variables: Creatinine <dbl[,1]>, Potassium <dbl[,1]>,
#   Sodium <dbl[,1]>, Chloride <dbl[,1]>, Bicarbonate <dbl[,1]>,
#   Hematocrit <dbl[,1]>, Glucose <dbl[,1]>, White_Blood_Cells <dbl[,1]>,
#   Heart_Rate <dbl[,1]>, Non_Invasive_Blood_Pressure_systolic <dbl[,1]>,
#   Non_Invasive_Blood_Pressure_diastolic <dbl[,1]>,
#   Temperature_Fahrenheit <dbl[,1]>, Respiratory_Rate <dbl[,1]>, …
# compute the AUC
st_auc <- roc_auc(icu_test_pred_st, truth = los_long, .pred_FALSE)
st_auc
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.622
# predict the test set by class
icu_pred_class <-
  test_set |>
  select(los_long) |>
  bind_cols(
    predict(
      icu_model_st,
      test_set,
      type = "class",
      members = TRUE
      )
    ) |>
  print()
       los_long .pred_class .pred_class_logit_res_1_001
    1:    FALSE       FALSE                       FALSE
    2:    FALSE       FALSE                       FALSE
    3:    FALSE       FALSE                       FALSE
    4:    FALSE       FALSE                       FALSE
    5:    FALSE       FALSE                       FALSE
   ---                                                 
36587:    FALSE       FALSE                       FALSE
36588:    FALSE       FALSE                       FALSE
36589:    FALSE        TRUE                       FALSE
36590:     TRUE       FALSE                       FALSE
36591:     TRUE       FALSE                       FALSE
       .pred_class_logit_res_1_101 .pred_class_logit_res_1_301
    1:                       FALSE                       FALSE
    2:                       FALSE                       FALSE
    3:                       FALSE                       FALSE
    4:                       FALSE                       FALSE
    5:                       FALSE                       FALSE
   ---                                                        
36587:                       FALSE                       FALSE
36588:                       FALSE                       FALSE
36589:                       FALSE                       FALSE
36590:                       FALSE                       FALSE
36591:                       FALSE                       FALSE
       .pred_class_rf_res_1_13 .pred_class_rf_res_1_15 .pred_class_rf_res_1_18
    1:                   FALSE                   FALSE                   FALSE
    2:                   FALSE                   FALSE                   FALSE
    3:                   FALSE                   FALSE                   FALSE
    4:                   FALSE                   FALSE                   FALSE
    5:                   FALSE                   FALSE                   FALSE
   ---                                                                        
36587:                   FALSE                   FALSE                   FALSE
36588:                   FALSE                   FALSE                   FALSE
36589:                    TRUE                    TRUE                    TRUE
36590:                   FALSE                   FALSE                   FALSE
36591:                   FALSE                   FALSE                   FALSE
       .pred_class_rf_res_1_31 .pred_class_rf_res_1_33 .pred_class_rf_res_1_37
    1:                    TRUE                   FALSE                    TRUE
    2:                    TRUE                    TRUE                   FALSE
    3:                   FALSE                   FALSE                   FALSE
    4:                   FALSE                   FALSE                   FALSE
    5:                   FALSE                   FALSE                   FALSE
   ---                                                                        
36587:                   FALSE                   FALSE                   FALSE
36588:                   FALSE                   FALSE                   FALSE
36589:                    TRUE                   FALSE                    TRUE
36590:                   FALSE                   FALSE                   FALSE
36591:                   FALSE                   FALSE                   FALSE
       .pred_class_rf_res_1_38 .pred_class_rf_res_1_39 .pred_class_rf_res_1_41
    1:                   FALSE                    TRUE                    TRUE
    2:                   FALSE                   FALSE                   FALSE
    3:                   FALSE                   FALSE                   FALSE
    4:                   FALSE                   FALSE                   FALSE
    5:                   FALSE                   FALSE                   FALSE
   ---                                                                        
36587:                   FALSE                   FALSE                   FALSE
36588:                   FALSE                   FALSE                   FALSE
36589:                   FALSE                   FALSE                   FALSE
36590:                   FALSE                   FALSE                   FALSE
36591:                   FALSE                   FALSE                   FALSE
       .pred_class_rf_res_1_48 .pred_class_gb_res_1_22 .pred_class_gb_res_1_28
    1:                   FALSE                   FALSE                    TRUE
    2:                    TRUE                   FALSE                    TRUE
    3:                   FALSE                   FALSE                    TRUE
    4:                   FALSE                    TRUE                    TRUE
    5:                   FALSE                   FALSE                    TRUE
   ---                                                                        
36587:                   FALSE                    TRUE                    TRUE
36588:                   FALSE                   FALSE                    TRUE
36589:                    TRUE                    TRUE                    TRUE
36590:                   FALSE                   FALSE                   FALSE
36591:                   FALSE                   FALSE                    TRUE
       .pred_class_gb_res_1_17 .pred_class_gb_res_1_20 .pred_class_gb_res_1_26
    1:                   FALSE                   FALSE                    TRUE
    2:                   FALSE                   FALSE                    TRUE
    3:                   FALSE                   FALSE                    TRUE
    4:                    TRUE                    TRUE                   FALSE
    5:                   FALSE                   FALSE                    TRUE
   ---                                                                        
36587:                   FALSE                   FALSE                    TRUE
36588:                   FALSE                   FALSE                    TRUE
36589:                    TRUE                    TRUE                    TRUE
36590:                   FALSE                   FALSE                   FALSE
36591:                   FALSE                   FALSE                    TRUE
       .pred_class_gb_res_1_12 .pred_class_gb_res_1_15 .pred_class_gb_res_1_18
    1:                   FALSE                   FALSE                   FALSE
    2:                   FALSE                   FALSE                   FALSE
    3:                   FALSE                   FALSE                   FALSE
    4:                   FALSE                   FALSE                    TRUE
    5:                   FALSE                   FALSE                   FALSE
   ---                                                                        
36587:                   FALSE                   FALSE                   FALSE
36588:                   FALSE                   FALSE                   FALSE
36589:                    TRUE                    TRUE                   FALSE
36590:                   FALSE                   FALSE                   FALSE
36591:                   FALSE                   FALSE                   FALSE
       .pred_class_gb_res_1_27
    1:                   FALSE
    2:                    TRUE
    3:                    TRUE
    4:                    TRUE
    5:                    TRUE
   ---                        
36587:                   FALSE
36588:                    TRUE
36589:                    TRUE
36590:                    TRUE
36591:                    TRUE
# compute the accuracy of each prediction column against the truth
# (los_long compared with itself gives 1 and is kept as an id column)
accuracy <- map(
  colnames(icu_pred_class),
  ~ mean(icu_pred_class$los_long == pull(icu_pred_class, .x))
  ) |>
  set_names(colnames(icu_pred_class)) |>
  as_tibble() |>
  pivot_longer(-los_long, names_to = "model", values_to = "accuracy") |>
  # sort the accuracy
  arrange(desc(accuracy)) |>
  print()
# A tibble: 23 × 3
   los_long model                   accuracy
      <dbl> <chr>                      <dbl>
 1        1 .pred_class_rf_res_1_37    0.586
 2        1 .pred_class                0.585
 3        1 .pred_class_rf_res_1_38    0.585
 4        1 .pred_class_rf_res_1_18    0.585
 5        1 .pred_class_rf_res_1_13    0.584
 6        1 .pred_class_rf_res_1_48    0.584
 7        1 .pred_class_rf_res_1_15    0.584
 8        1 .pred_class_gb_res_1_15    0.584
 9        1 .pred_class_rf_res_1_31    0.583
10        1 .pred_class_gb_res_1_12    0.583
# ℹ 13 more rows
# save the accuracy of the stacking model
# (row 1 of the sorted table; the ensemble's own combined prediction is the
# .pred_class row)
st_accuracy <- accuracy$accuracy[1]
st_accu <- data.frame(
  .metric = "accuracy",
  .estimator = "binary",
  .estimate = st_accuracy
)

4.2 Feature importance

Here I plot the feature importance of the random forest and boosting models. From the two plots, we can see that the most important features in predicting long ICU stays are age_intime, Heart_Rate, Non_Invasive_Blood_Pressure_systolic, White_Blood_Cells, Glucose, and Hematocrit. Among these, age_intime is the patient’s age at ICU intime; Heart_Rate and Non_Invasive_Blood_Pressure_systolic are the first vital measurements during the ICU stay; and White_Blood_Cells, Glucose, and Hematocrit are the last lab measurements before the ICU stay.

# feature importance of RF
fitted_wf <- rf_fit$.workflow[[1]]
fitted_model <- extract_fit_engine(fitted_wf)
importance <- fitted_model$variable.importance
importance_rf <- data.frame(
  Feature = names(importance),
  Importance = importance,
  stringsAsFactors = FALSE
)
importance_rf <- importance_rf[order(importance_rf$Importance, decreasing = TRUE), ]

# plot the feature importance
p <- ggplot(importance_rf, aes(x = reorder(Feature, Importance), y = Importance)) +
  geom_col() +
  coord_flip() +
  geom_text(aes(label = round(Importance, 2)), hjust = -0.1) +
  labs(x = "Features", y = "Importance") +
  theme(panel.background = element_rect(fill = "white"),
        panel.grid = element_line(color = "gray", linewidth = 0.2),
        panel.border = element_rect(fill = NA, linewidth = 0.5), 
        # increase the font size
        axis.text = element_text(size = 12),
        axis.title = element_text(size = 15)
        )

ggsave("feature_importance_rf.png", p, width = 18, height = 8)

Feature Importance of RF
# feature importance of GB
library(xgboost)  # for xgb.importance()
fitted_gbwf <- gb_fit$.workflow[[1]]
fitted_gbmodel <- extract_fit_engine(fitted_gbwf)
importance_gb <- xgb.importance(model = fitted_gbmodel)

p <- ggplot(importance_gb, aes(x = reorder(Feature, Gain), y = Gain)) +
  geom_col() +
  coord_flip() +
  geom_text(aes(label = round(Gain, 2)), hjust = -0.1) +
  labs(x = "Features", y = "Gain") +
  theme(panel.background = element_rect(fill = "white"),
        panel.grid = element_line(color = "gray", linewidth = 0.2),
        panel.border = element_rect(fill = NA, linewidth = 0.5),
        # increase the font size
        axis.text = element_text(size = 12),
        axis.title = element_text(size = 15))

ggsave("feature_importance_gb.png", p, width = 18, height = 8)

Feature Importance of XGB

4.3 Model comparison in terms of performance and interpretability

4.3.1 Performance (stacking > RF > XGB > logit)

From the performance table and ROC curves below, we can see that

  • Model Stacking shows the best performance in terms of both accuracy and ROC AUC, slightly outperforming the other models. This indicates that combining the predictions of multiple models can lead to a slight improvement in predictive performance.

  • Random Forest and Boosting models perform similarly and better than logistic regression, suggesting that tree-based methods may capture complex patterns in the data more effectively.

  • Logistic Regression shows the lowest performance among the models in both metrics, which might be due to its linear nature being less capable of handling complex relationships in the data.

# compare the models in terms of performance
library(broom)
metrics_list <- list(
  logit = logit_fit %>% collect_metrics(),
  rf = rf_fit %>% collect_metrics(),
  boosting = gb_fit %>% collect_metrics(),
  stacking = st_accu, 
  stacking = st_auc
)

# tabulate the accuracy and roc_auc results
bind_rows(metrics_list, .id = "model") |>
  filter(.metric == "accuracy" | .metric == "roc_auc") |>
  select(-.estimator, -.config) |>
  pivot_wider(names_from = .metric, values_from = .estimate) |>
  print()
# A tibble: 4 × 3
  model    accuracy roc_auc
  <chr>       <dbl>   <dbl>
1 logit       0.573   0.600
2 rf          0.586   0.620
3 boosting    0.583   0.617
4 stacking    0.586   0.622
# plot the ROC curves
library(yardstick)
logit_preds <- logit_fit |> collect_predictions()
rf_preds <- rf_fit |> collect_predictions()
gb_preds <- gb_fit |> collect_predictions()
logit_roc <- logit_preds |> roc_curve(los_long, .pred_FALSE)
rf_roc <- rf_preds |> roc_curve(los_long, .pred_FALSE)
gb_roc <- gb_preds |> roc_curve(los_long, .pred_FALSE)
st_roc <- icu_test_pred_st |> roc_curve(los_long, .pred_FALSE)
# combine the curves; the name roc_df avoids masking yardstick::roc_curve()
roc_df <- bind_rows(
  logit_roc %>% mutate(model = "logit"),
  rf_roc %>% mutate(model = "rf"), 
  gb_roc %>% mutate(model = "XGB"), 
  st_roc %>% mutate(model = "stacking"))
roc_df |> ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
  geom_line() +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  labs(title = "ROC Curves of Different Models",
       x = "False Positive Rate",
       y = "True Positive Rate") +
  theme(panel.background = element_rect(fill = "white"),
        panel.grid = element_line(color = "gray", linewidth = 0.2),
        panel.border = element_rect(fill = NA, linewidth = 0.5))

4.3.2 Interpretability (logit > rf > XGB > stacking)

  • Logistic Regression with ENet Regularization: High. Logistic regression models provide coefficients for each feature, making it straightforward to understand the impact of each feature on the prediction. Elastic Net regularization, which combines L1 and L2 penalties, can further enhance interpretability by promoting sparsity and reducing the influence of less important features.

  • Random Forest: Moderate. While random forests offer insights into feature importance, indicating which features are most influential in making predictions, the ensemble nature of the model (comprising many decision trees) makes it harder to trace the exact decision path for specific predictions.

  • Boosting Models: Moderate to Low. Similar to random forests, boosting models can provide measures of feature importance. However, the sequential correction of errors in boosting adds complexity, making the exact reasoning behind predictions less transparent than simpler models.

  • Model Stacking: Low. Stacking involves combining the predictions from multiple models, which inherently reduces interpretability. Understanding how individual predictions contribute to the final decision can be challenging, as it depends on the interplay between different base models and possibly a meta-model.
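As a hedged illustration of these differences, and assuming the `logit_fit` and `icu_model_st` objects fitted above (with candidate-set names such as "logit_res" matching the member columns printed earlier), the interpretable pieces of each model could be pulled out roughly like this:

```r
# Sketch only: inspect what each model exposes for interpretation.
library(tidymodels)
library(stacks)

# Elastic-net logistic regression: one coefficient per feature,
# sorted by magnitude of effect
logit_fit$.workflow[[1]] |>
  extract_fit_parsnip() |>
  tidy() |>
  arrange(desc(abs(estimate)))

# Stacking: the meta-learner's blending weights show how much each candidate
# contributes, but not why any base model predicts what it does
icu_model_st |>
  collect_parameters("logit_res")
```

The contrast is visible in the outputs themselves: the logistic model yields a directly readable coefficient table, while the stack only reports weights over opaque candidates.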

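To make the "sequential correction of errors" point about boosting concrete, here is a toy residual-boosting loop on simulated data (an illustration of the idea only, not the xgboost procedure used above):

```r
# Toy boosting: each round fits a weak learner to the current residuals and
# adds a shrunken copy of its fit to the ensemble prediction.
set.seed(203)
x <- runif(200)
y <- sin(2 * pi * x) + rnorm(200, sd = 0.2)

pred <- rep(0, length(y))     # ensemble starts at zero
for (m in 1:50) {
  resid <- y - pred                     # errors left by the ensemble so far
  weak <- lm(resid ~ poly(x, 3))        # weak learner fit to the residuals
  pred <- pred + 0.1 * predict(weak)    # shrunken additive update
}

mean((y - pred)^2) < var(y)   # training MSE has dropped below the baseline
```

Because the final prediction is a sum of many round-specific corrections, no single fitted component explains an individual prediction, which is exactly what limits the interpretability of boosted models.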
4.3.3 Conclusions on model comparison

In summary, the choice of model should consider the trade-off between performance and interpretability. Logistic regression offers high interpretability but lower performance, while random forests and boosting models provide a good balance between the two. Model stacking, while showing a slight performance advantage, sacrifices interpretability due to its complex nature. Therefore, the choice of model should be based on the specific needs of the application, considering the importance of interpretability, the complexity of the data, and the desired level of predictive performance.

5. Conclusions

In this homework, I developed machine learning approaches to predict whether a patient’s ICU stay will be longer than 2 days using the MIMIC-IV ICU cohort. First, I preprocessed the data by imputing missing values, handling outliers, and correcting data types. Then, I partitioned the data into a 50% training set and a 50% test set, and trained and tuned logistic regression with enet regularization, random forest, and boosting models, plus a model stacking approach. Finally, I compared the classification performance of the models on the test set, identified the most important features in predicting long ICU stays, and evaluated the models in terms of performance and interpretability. The results showed that model stacking slightly outperformed the other models in accuracy and ROC AUC, while logistic regression provided the highest interpretability. age_intime, Heart_Rate, Non_Invasive_Blood_Pressure_systolic, White_Blood_Cells, Glucose, and Hematocrit were identified by the RF and boosting models as the most important features in predicting long ICU stays.